-
Notifications
You must be signed in to change notification settings - Fork 31k
[Attn Masks] Non-vmap default for attention masks
#41852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
WIP][Masking] Non-vmap default for attention masksAttn Masks] Non-vmap default for attention masks
| return cache | ||
|
|
||
|
|
||
| def sdpa_mask_without_vmap( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No longer needed as vmap was the reason we needed this workaround in the first place
| NOTE: It is important to keep an index-based version for non-vmap expansion. | ||
| """ | ||
| return q_idx.new_ones((), dtype=torch.bool) | ||
| return q_idx >= 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As noted above, for non-vmap we need this as index based version
| causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) | ||
| return causal_mask | ||
|
|
||
| attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I encountered issues with the inplace version where we'd need a clone (e.g. when using swa). This is safer
Non-vmap creation of masks. These work with all our base masks and we only default back to vmap when using patterns we cannot guarantee (i.e. additional and/or masks).
Note:
Fixes #41639
cc @jiqing-feng @IlyasMoutawwakil